"""Wrapper for C API of POF."""
import abc
import os
from time import time
import ctypes
from ctypes import *
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from numpy.ctypeslib import as_array
import torch
from torch import nn
import deepspeed

# from .utils import *

# String spellings of the ctypes pointer types; they are used as (unevaluated)
# return annotations by the array-conversion helpers in this module.
_ctypes_double_ptr = "ctypes._Pointer[ctypes.c_double]"

_ctypes_int_ptr = "ctypes._Pointer[ctypes.c_int32]"

def _is_1d_list(data: Any) -> bool:
    """Check whether data is a 1-D list."""
    return isinstance(data, list) and (not data or _is_numeric(data[0]))

def _is_numpy_1d_array(data: Any) -> bool:
    """Check whether data is a numpy 1-D array."""
    return isinstance(data, np.ndarray) and len(data.shape) == 1

def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
    """Fix the memory of multi-dimensional sliced object."""
    if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
        if not data.flags.c_contiguous:
            _log_warning(
                "Usage of np.ndarray subset (sliced data) is not recommended "
                "due to it will double the peak memory cost"
            )
            return np.copy(data)
    return data

def _c_double_array(data: np.ndarray) -> _ctypes_double_ptr:
    """Get a ``double*`` pointer to a 1-D float numpy array / list.

    Parameters
    ----------
    data : np.ndarray or list
        1-D data of dtype float32 or float64. float32 input is widened to a
        float64 copy; float64 input is used in place.

    Returns
    -------
    ctypes pointer to c_double
        Pointer to the first element of the (possibly converted) array.

    Raises
    ------
    TypeError
        If ``data`` is neither a 1-D list nor a 1-D numpy array, or its dtype
        is not float32 / float64.
    """
    if _is_1d_list(data):
        # np.asarray rather than np.array(..., copy=False): on NumPy >= 2.0
        # copy=False raises ValueError whenever a copy is unavoidable, which
        # is always true when converting a Python list.
        data = np.asarray(data)
    if not _is_numpy_1d_array(data):
        raise TypeError(f"Unknown type({type(data).__name__})")

    data = _convert_from_sliced_object(data)
    assert data.flags.c_contiguous
    if data.dtype == np.float32:
        # Widen so the buffer layout really is an array of C doubles.
        data = data.astype(np.float64)
    elif data.dtype != np.float64:
        raise TypeError(f"Expected np.float32 or np.float64, met type({data.dtype})")
    # Per the NumPy documentation the pointer returned by data_as keeps a
    # reference to the array, so a widened temporary lives as long as the
    # pointer object does.
    return data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))


def _c_int_array(data: np.ndarray) -> _ctypes_int_ptr:
    """Get an ``int32_t*`` pointer to a 1-D integer numpy array / list.

    Parameters
    ----------
    data : np.ndarray or list
        1-D data of dtype int32 or int64. int64 input is narrowed to an int32
        copy; int32 input is used in place.

    Returns
    -------
    ctypes pointer to c_int32
        Pointer to the first element of the (possibly converted) array. Only
        the pointer is returned (the previous annotation advertised a tuple
        that was never produced).

    Raises
    ------
    TypeError
        If ``data`` is neither a 1-D list nor a 1-D numpy array, or its dtype
        is not int32 / int64.
    """
    if _is_1d_list(data):
        # np.asarray rather than np.array(..., copy=False): on NumPy >= 2.0
        # copy=False raises ValueError whenever a copy is unavoidable, which
        # is always true when converting a Python list.
        data = np.asarray(data)
    if not _is_numpy_1d_array(data):
        raise TypeError(f"Unknown type({type(data).__name__})")

    data = _convert_from_sliced_object(data)
    assert data.flags.c_contiguous
    if data.dtype == np.int64:
        # NOTE(review): astype narrows without a range check, so values that
        # do not fit in int32 wrap around silently — confirm inputs are small.
        data = data.astype(np.int32)
    elif data.dtype != np.int32:
        raise TypeError(f"Expected np.int32 or np.int64, met type({data.dtype})")
    # Per the NumPy documentation the pointer returned by data_as keeps a
    # reference to the array, so a narrowed temporary lives as long as the
    # pointer object does.
    return data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))

def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes float pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected float pointer')


def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes double pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected double pointer')


def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected int32 pointer')


def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected int64 pointer')
    
def _load_lib() -> ctypes.CDLL:
    """Load the POF shared library located next to this module.

    After loading ``lib_POF.so`` the return types of the entry points that do
    not return a plain C int by default are declared, together with the
    argument types of ``POF_update_intest_nndata``.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    lib = ctypes.cdll.LoadLibrary(os.path.join(module_dir, "lib_POF.so"))

    lib.POF_get_solution.restype = ctypes.c_int
    # These three return a C double.
    for fn_name in ("POF_calc_approx_loss", "POF_get_posterior", "POF_get_bias"):
        getattr(lib, fn_name).restype = ctypes.c_double

    lib.POF_update_intest_nndata.argtypes = [
        ctypes.c_int,
        ctypes.c_int,
        ctypes.POINTER(ctypes.c_int32),
        ctypes.POINTER(ctypes.c_double),
        ctypes.c_void_p,
    ]
    return lib
    
# Module-wide handle to the POF shared library; loaded once, at import time.
_LIB = _load_lib()

class POFError(Exception):
    """Error raised when a POF C API call reports a failure (see ``_safe_call``)."""

    pass


def _safe_call(ret: int) -> None:
    """Check the return value from C API call.

    Parameters
    ----------
    ret : int
        The return value from C API calls.
    """
    if ret != 0:
        raise POFError(_LIB.POF_GetLastError())#.decode("utf-8"))

"""
User defined methods (cpp)
    - get_solution
    - approx_loss

USER_PARAMETER



Indata(cpp) struct
    -normaltable
    -normalclass
    -posterior

POF_Model
    attributes
    - Model (perception module, python)
    - Indata 

    methods
    - __init__:
        set_userparameter
        init_indata
    - get_solution(): solves the optimization problem
    - update_nndata(1, train_sample_index, train_internal_test_index_array): 
        updates the optimization problem & corresponding perception module output

    (Used for training)
	- set_normaltable(NN_DATA, USER_PARAMETER, &INDATA, 1) from _LIB
	- set_normalclass(NN_DATA, &INDATA) from _LIB
	- set_posterior(NN_DATA, USER_PARAMETER, &INDATA) from _LIB
"""
      
class BasePOFEngine:
    def __init__(
        self,
        model: "Union[deepspeed.DeepSpeedEngine, torch.nn.Module]",
        optimizer: Optional[torch.optim.Optimizer] = None,
        n_rank = 2,
        th = 0.1,
        beta = 1000.0,
        min_loss = 3000.0,
        lambda_param = 0.03,
        xi = 0.03,
        context = ((0.5, 0.5),),
        n_sample = 100,
        m_sample = 100,
        forward_logit = False
    ):
        """
        Create the engine and allocate the C-side state.

        Option 1)
        model is deepspeed.DeepSpeedEngine which has following methods:
        model.forward(input) = model.__call__(input)
        model.backward(loss): backpropagation
        model.step(): optimizer step

        optimizer is None

        Option 2)
        model is torch.nn.Module which has method
        model.forward(input)

        optimizer is torch.optim.Optimizer which is
        initialized with model parameters and has method
        optimizer.step()

        Parameters
        ----------
        model, optimizer
            See the two options above.
        n_rank, th, beta, min_loss, lambda_param, xi, context
            Forwarded unchanged to ``POF_set_userparameter``; ``context`` is
            flattened to a 1-D float64 array first. ``xi`` and ``min_loss``
            are additionally kept on the instance.
        n_sample, m_sample
            Stored on the instance.
        forward_logit
            When true, ``__call__`` returns ``get_logits(...)`` instead of
            running the train / inference forward pass.
        """
        self.n_class = 2
        self.module = model
        self.optimizer = optimizer
        # Opaque handles owned by the C library; filled in by the POF_init_* calls below.
        self._user_parameters = ctypes.c_void_p()
        self._nndata = ctypes.c_void_p()
        self._indata = ctypes.c_void_p()

        context_list = np.array(context, dtype=np.float64).flatten()
        self.xi = xi
        self.min_loss = min_loss
        _LIB.POF_init_userparameter(ctypes.byref(self._user_parameters))
        _LIB.POF_set_userparameter(
            n_rank,
            ctypes.c_double(th),
            ctypes.c_double(beta),
            ctypes.c_double(min_loss),
            ctypes.c_double(lambda_param),
            ctypes.c_double(xi),
            _c_double_array(context_list),
            self._user_parameters,
        )

        _LIB.POF_init_nndata(ctypes.byref(self._nndata))
        _LIB.POF_init_indata(ctypes.byref(self._indata))

        self.approx_loss = None

        self.n_sample = n_sample
        self.m_sample = m_sample

        # A plain torch.nn.Module (option 2) has no `.device` attribute, so
        # fall back to the device of its first parameter, then to the CPU.
        device = getattr(model, "device", None)
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device("cpu")

        # Internal test-case logits and labels; both start out empty.
        self.test_case_ccon_w = torch.Tensor([]).to(device)
        self.ans = torch.Tensor([]).to(device)

        # Set by the forward passes / add_indata; None until then.
        self.ccon_w = None
        self.intest_acc = None
        self.intest_FPR = None
        self.intest_FNR = None

        # Latest gradients returned by the C library (set in loss()).
        self.ccon_grad = None
        self.test_case_ccon_grad = None

        # Short rolling histories used by plot_grads().
        self.test_case_ccon_grad_list = []
        self.ans_list = []

        self.forward_logit = forward_logit

    def update_xi(self, xi):
        """Store a new ``xi`` on the engine and push it to the C-side user parameters."""
        self.xi = xi
        c_xi = c_double(xi)
        _LIB.POF_update_xi(c_xi, self._user_parameters)
        print("Updated xi to: ", xi)

    def update_min_loss(self, min_loss):
        """Store a new ``min_loss`` on the engine and push it to the C-side user parameters."""
        self.min_loss = min_loss
        c_min_loss = c_double(min_loss)
        _LIB.POF_update_min_loss(c_min_loss, self._user_parameters)
        print("Updated min_loss to: ", min_loss)

    def update_indata(self, set_bias=True):
        """Send the stored internal test cases to the C library and refresh its indata.

        The labels in ``self.ans`` (as int32) and the logits in
        ``self.test_case_ccon_w`` (as float64) are flattened and handed to
        ``POF_update_intest_nndata``. Afterwards ``POF_set_bias`` is called
        when ``set_bias`` is true, ``POF_update_indata`` otherwise.
        """
        n_internal_test_case = len(self.ans)

        labels_np = self.ans.numpy(force=True).astype(np.int32).flatten()
        labels_ptr = _c_int_array(labels_np)
        logits_np = self.test_case_ccon_w.double().numpy(force=True).flatten()
        logits_ptr = _c_double_array(logits_np)
        _LIB.POF_update_intest_nndata(
            n_internal_test_case, self.n_class, labels_ptr, logits_ptr, self._nndata
        )

        if not set_bias:
            _LIB.POF_update_indata(1, self._nndata, self._user_parameters, self._indata)
            return
        print("Setting bias")
        _LIB.POF_set_bias(1, self._nndata, self._user_parameters, self._indata)

    def _update_traindata(self, ccon_w, objective):
        """Remember ``ccon_w`` and send it, with ``objective``, to the C library.

        Both tensors are converted to flattened float64 arrays and passed to
        ``POF_update_train_nndata`` together with ``len(objective)`` and
        ``self.n_class``.
        """
        self.ccon_w = ccon_w
        batch_size = len(objective)

        objective_np = objective.double().numpy(force=True).flatten()
        objective_ptr = _c_double_array(objective_np)
        ccon_w_np = self.ccon_w.double().numpy(force=True).flatten()
        ccon_w_ptr = _c_double_array(ccon_w_np)
        _LIB.POF_update_train_nndata(
            batch_size, self.n_class, objective_ptr, ccon_w_ptr, self._nndata
        )

    def train_forward(self, input, objective, ans, *args, **kwargs): # override for customized inputs
        """Run the model on a training batch followed by internal test cases, then solve.

        ``input`` holds ``len(objective)`` training samples followed by
        ``len(ans)`` internal test cases. ``ans`` uses 1 for safe and 0 for
        unsafe; a negative model score is treated as a "safe" prediction.

        Side effects: updates ``intest_acc`` / ``intest_FPR`` / ``intest_FNR``,
        ``test_case_ccon_w``, ``ans``, ``ans_list`` (last 10 kept), ``ccon_w``
        and ``approx_loss``, and pushes the new data to the C library.

        Returns the value of ``POF_get_solution``.
        """
        # NOTE(review): the caller's objective values are discarded and
        # replaced by small Gaussian noise — confirm this is intentional.
        objective = torch.randn_like(objective)/1000.0

        batch_size = len(objective)
        n_internal_test_case = len(ans)
        assert batch_size + n_internal_test_case == len(input)

        output = self.module(input, *args, **kwargs).end_scores
        # Slice from the front rather than with `output[-n:]`: for n == 0 the
        # latter is `output[-0:]`, i.e. the *whole* batch, which would treat
        # every training row as a test case.
        # Assumes the model returns one score row per input row — TODO confirm.
        intest_scores = output[batch_size:]
        predicted_safe = intest_scores < 0
        predicted_unsafe = intest_scores > 0
        # ans: 1 for safe / 0 for unsafe
        self.intest_acc = (predicted_safe.flatten().int() == ans).double().mean()
        self.intest_FPR = torch.where(predicted_safe.flatten(), 1 - ans, 0.).sum() / predicted_safe.int().sum()
        self.intest_FNR = torch.where(predicted_unsafe.flatten(), ans, 0.).sum() / predicted_unsafe.int().sum()

        # Append a zero column so every row carries two class logits.
        output = torch.cat((output, torch.zeros(output.shape, device = output.device)), -1)
        ccon_w = output[:batch_size]
        self.test_case_ccon_w = output[batch_size:]
        self.ans = ans
        # Rolling label history for plot_grads().
        self.ans_list.append(ans)
        self.ans_list = self.ans_list[-10:]

        self.update_indata(set_bias=False)
        self._update_traindata(ccon_w, objective)

        solution = _LIB.POF_get_solution(self._nndata, self._indata, self._user_parameters)
        approx_loss = _LIB.POF_calc_approx_loss(self._nndata, self._indata, self._user_parameters)
        self.approx_loss = torch.Tensor([approx_loss]).to(input.device)

        print("Solution: ", solution)
        return solution

    def add_indata(self, intest_input, ans, *args, change = False, **kwargs):
        """Score extra internal test cases and store them (on the CPU).

        Parameters
        ----------
        intest_input
            Model input for the test cases; must have the same length as ``ans``.
        ans
            Labels, 1 for safe / 0 for unsafe.
        change
            When true the stored test cases are replaced; otherwise the new
            ones are appended to the existing ones.

        Side effects: updates ``intest_acc`` / ``intest_FPR`` / ``intest_FNR``,
        ``test_case_ccon_w`` and ``ans``. Nothing is sent to the C library here.
        """
        assert len(intest_input) == len(ans)
        output = self.module(intest_input, *args, **kwargs).end_scores.detach().cpu()
        ans = ans.cpu()
        # ans: 1 for safe / 0 for unsafe. A negative score is a "safe"
        # prediction, matching train_forward; the previous code compared the
        # integer-truncated raw score with the label instead.
        predicted_safe = output < 0
        predicted_unsafe = output > 0
        self.intest_acc = (predicted_safe.flatten().int() == ans).double().mean()
        self.intest_FPR = torch.where(predicted_safe.flatten(), 1 - ans, 0.).sum() / predicted_safe.int().sum()
        self.intest_FNR = torch.where(predicted_unsafe.flatten(), ans, 0.).sum() / predicted_unsafe.int().sum()

        # Append a zero column so every row carries two class logits.
        test_case_ccon_w = torch.cat((output, torch.zeros(output.shape, device = output.device)), -1)
        if change:
            self.test_case_ccon_w = test_case_ccon_w
            self.ans = ans
        else:
            self.test_case_ccon_w = torch.cat((self.test_case_ccon_w.cpu(), test_case_ccon_w), 0)
            self.ans = torch.cat((self.ans.cpu(), ans), 0)
        


    def inference_forward(self, inference_input, objective, *args, **kwargs): # override for customized inputs
        """Run the model on an inference batch and solve with the stored test cases.

        The internal test cases already held by the C library are left
        untouched; only the training-side data (``ccon_w`` and ``objective``)
        is refreshed. Also updates ``self.approx_loss``.

        Returns the value of ``POF_get_solution``.
        """
        assert len(inference_input) == len(objective)
        # NOTE(review): the caller's objective values are discarded and
        # replaced by small Gaussian noise — confirm this is intentional.
        objective = torch.randn_like(objective)/1000.0

        output = self.module(inference_input, *args, **kwargs).end_scores

        # Append a zero column so every row carries two class logits.
        ccon_w = torch.cat((output, torch.zeros(output.shape, device = output.device)), -1)
        self._update_traindata(ccon_w, objective)

        solution = _LIB.POF_get_solution(self._nndata, self._indata, self._user_parameters)
        approx_loss = _LIB.POF_calc_approx_loss(self._nndata, self._indata, self._user_parameters)
        self.approx_loss = torch.Tensor([approx_loss]).to(inference_input.device)

        print("Solution: ", solution)
        return solution


    def get_logits(self, input, *args, **kwargs):
        """Return the model's ``end_scores`` shifted by the bias held in the C library."""
        scores = self.module(input, *args, **kwargs).end_scores
        return scores + _LIB.POF_get_bias(self._nndata)

    def loss(self) -> Dict[str, Any]:
        """Fetch gradients from the C library and build surrogate losses plus diagnostics.

        Must be called after a forward pass has set ``self.ccon_w`` and
        ``self.test_case_ccon_w``.

        Side effects: sets ``self.ccon_grad`` and ``self.test_case_ccon_grad``,
        appends to ``self.test_case_ccon_grad_list`` (last 10 kept) and prints
        timing / debugging information.

        Returns
        -------
        dict
            ``train_loss`` and ``intest_loss`` are ``sum(gradient * logits)``
            terms; the remaining entries are regularisers and monitoring
            values (accuracy, margins, posteriors, ...).

        Raises
        ------
        POFError
            If one of the gradient computations in the C library fails.
        """
        def _as_tensor_like(buffer, like):
            # Copy a ctypes double buffer into a tensor with the shape and device of `like`.
            flat = torch.Tensor(as_array(buffer, (like.numel(), )))
            return torch.reshape(flat.to(like.device), like.shape)

        # Output buffers for the C library; numel() gives the plain Python int
        # that ctypes array sizing expects.
        ccon_gradient_list = (c_double * self.ccon_w.numel())()
        test_case_ccon_gradient_list = (c_double * self.test_case_ccon_w.numel())()

        start_time = time()
        _safe_call(_LIB.POF_fullcalc_grad_ccon(self._nndata, self._user_parameters, self._indata, ccon_gradient_list))
        _safe_call(_LIB.POF_fullcalc_approx_grad_test_case_ccon(self._nndata, self._user_parameters, self._indata, test_case_ccon_gradient_list))
        print("Time for fullcalc: ", time() - start_time)

        ccon_gradient = _as_tensor_like(ccon_gradient_list, self.ccon_w)
        self.ccon_grad = ccon_gradient

        test_case_ccon_gradient = _as_tensor_like(test_case_ccon_gradient_list, self.test_case_ccon_w)
        self.test_case_ccon_grad = test_case_ccon_gradient
        print("test_case_ccon_gradient shape: ", test_case_ccon_gradient.shape)

        if torch.isnan(test_case_ccon_gradient).any():
            # Dump the C-side state for debugging, then sanitise so training can continue.
            print("Warning: Found NaN in test_case_ccon_gradient.")
            print("NN data: ")
            _LIB.POF_print_nndata(self._nndata)
            print("User parameters: ")
            _LIB.POF_print_userparameter(self._user_parameters)
            print("Indata: ")
            _LIB.POF_print_indata(self._indata)

            test_case_ccon_gradient = torch.nan_to_num(test_case_ccon_gradient, nan=0.0, posinf=1.0, neginf=-1.0)

        # Rolling history for plot_grads(). The column slice is already 1-D;
        # squeezing it would turn a single test case into a 0-d tensor, which
        # torch.concat cannot join.
        self.test_case_ccon_grad_list.append(test_case_ccon_gradient[:, 0])
        self.test_case_ccon_grad_list = self.test_case_ccon_grad_list[-10:]

        test_case_first_logit = self.test_case_ccon_w.transpose(0, 1)[0]
        test_case_first_grad = test_case_ccon_gradient.transpose(0, 1)[0]

        # NOTE(review): CrossEntropyLoss receives a 1-D logit vector and a 1-D
        # float target here — confirm this is the intended loss form.
        ce_loss = nn.CrossEntropyLoss().to(self.test_case_ccon_w.device)(test_case_first_logit, 1. - self.ans)

        output = torch.cat((self.ccon_w, self.test_case_ccon_w), 0)
        norm_regularizer = torch.mean(torch.linalg.vector_norm(output, dim=-1, ord=2))

        # Mean first-column gradient, with rows of the other label zeroed out.
        test_case_ccon_gradient_true_class = torch.where(self.ans==0, test_case_first_grad, 0.).mean()
        test_case_ccon_gradient_false_class = torch.where(self.ans==1, test_case_first_grad, 0.).mean()

        safe_answer_rate = torch.where(self.ccon_w.transpose(0, 1)[0]<0, 1., 0.).mean()

        abs_first_logit = torch.abs(output.transpose(0, 1)[0])
        med_margin = torch.median(abs_first_logit)
        min_margin = torch.min(abs_first_logit)
        max_margin = torch.max(abs_first_logit)

        postprob_0_0 = _LIB.POF_get_posterior(self._indata, 0, 0, 0)
        postprob_1_0 = _LIB.POF_get_posterior(self._indata, 0, 1, 0)
        print(f"P[0|0]: {postprob_0_0}, P[0|1]: {postprob_1_0}")
        _LIB.POF_print_indata(self._indata)

        return {
            'train_loss': torch.sum(ccon_gradient * self.ccon_w),
            'intest_loss': torch.sum(test_case_ccon_gradient * self.test_case_ccon_w),
            'norm_regularizer': norm_regularizer,
            'ce_loss': ce_loss,
            'approx_loss': self.approx_loss,
            'intest_acc': self.intest_acc,
            'intest_FPR': self.intest_FPR,
            'intest_FNR': self.intest_FNR,
            'True_class_intest_grad': test_case_ccon_gradient_true_class,
            'False_class_intest_grad': test_case_ccon_gradient_false_class,
            'safe_answer_rate': safe_answer_rate,
            'med_margin': med_margin,
            'min_margin': min_margin,
            'max_margin': max_margin,
            'P[0|0]': postprob_0_0,
            'P[0|1]': postprob_1_0,
        }
    
    def backward(self, loss: torch.Tensor):
        """Back-propagate ``loss``.

        Uses the wrapped module's own ``backward(loss)`` when it has one;
        otherwise, if an optimizer was supplied, falls back to
        ``loss.backward()`` so that a plain ``torch.nn.Module`` works too
        (mirroring the fallback in ``step``).

        Raises
        ------
        AssertionError
            If the module has no ``backward`` and no optimizer was supplied.
        """
        if hasattr(self.module, 'backward'):
            self.module.backward(loss)
        elif self.optimizer is not None:
            loss.backward()
        else:
            raise AssertionError("Module has no attribute 'backward' and optimizer is None")
    
    def step(self):
        """Advance the optimisation by one step.

        The wrapped module's own ``step()`` is preferred when it exists;
        otherwise the optimizer given at construction time is stepped. No
        gradient clipping or NaN sanitising is performed here.

        Raises
        ------
        AssertionError
            If the module has no ``step`` and no optimizer was supplied.
        """
        if hasattr(self.module, 'step'):
            self.module.step()
            return
        if self.optimizer is None:
            raise AssertionError("Module has no attribute 'step' and optimizer is None")
        self.optimizer.step()

    def train(self):
        """Switch the wrapped module to training mode through its own ``train()``."""
        wrapped = self.module
        assert hasattr(wrapped, 'train'), "Module has no attribute 'train'"
        wrapped.train()

    def eval(self):
        """Switch the wrapped module to evaluation mode through its own ``eval()``."""
        wrapped = self.module
        assert hasattr(wrapped, 'eval'), "Module has no attribute 'eval'"
        wrapped.eval()

    def save_checkpoint(self, output_dir, tag: Optional = None):
        """Delegate checkpoint saving to the wrapped module's ``save_checkpoint``.

        ``output_dir`` and ``tag`` are passed through positionally, unchanged.
        """
        wrapped = self.module
        assert hasattr(wrapped, 'save_checkpoint'), "Module has no attribute 'save_checkpoint'"
        wrapped.save_checkpoint(output_dir, tag)

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing on the wrapped module.

        Tries the wrapper first and, if it does not provide the method, the
        network stored under its ``.module`` attribute.
        """
        try:
            self.module.gradient_checkpointing_enable()
        except AttributeError:
            # Only a missing method triggers the fallback; a bare `except:`
            # would also swallow KeyboardInterrupt and unrelated failures.
            self.module.module.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the wrapped module.

        Tries the wrapper first and, if it does not provide the method, the
        network stored under its ``.module`` attribute.
        """
        try:
            self.module.gradient_checkpointing_disable()
        except AttributeError:
            # Only a missing method triggers the fallback; a bare `except:`
            # would also swallow KeyboardInterrupt and unrelated failures.
            self.module.module.gradient_checkpointing_disable()

    def __call__(self, *args, **kwargs):
        """Dispatch to the right forward pass.

        With ``forward_logit`` set, returns ``get_logits(...)`` and requires
        the network to be in eval mode. Otherwise runs ``train_forward`` when
        the network is in training mode and ``inference_forward`` when not.
        """
        # Engine-style wrappers keep the real network under `.module`; a plain
        # torch.nn.Module has no such attribute, so use the module itself then.
        network = getattr(self.module, "module", self.module)
        if self.forward_logit:
            assert not network.training, "Module is in training mode"
            return self.get_logits(*args, **kwargs)
        if network.training:
            return self.train_forward(*args, **kwargs)
        return self.inference_forward(*args, **kwargs)

    def plot_logits(self, save_path: str):
        """
        Plot the first logit of every stored internal test case and save the figure.

        Blue points represent safe samples (ans=1) and red points represent unsafe samples (ans=0).
        Safe samples are plotted on the left side and unsafe samples on the right.
        A red horizontal line marks ``-bias`` and a text box summarises the counts.

        Parameters
        ----------
        save_path : str
            Destination passed to ``matplotlib.pyplot.savefig``.
        """
        import matplotlib.pyplot as plt

        # detach(): the stored logits can still be attached to the autograd
        # graph (train_forward keeps the raw model output), and Tensor.numpy()
        # rejects tensors that require grad. double() mirrors update_indata.
        logits = self.test_case_ccon_w[:, 0].detach().double().cpu().numpy()
        labels = self.ans.detach().cpu().numpy()

        bias = _LIB.POF_get_bias(self._nndata)
        postprob_1_0 = _LIB.POF_get_posterior(self._indata, 0, 1, 0)
        print("Bias: ", bias)
        print("P[1|0]: ", postprob_1_0)

        # Separate safe and unsafe samples
        safe_indices = np.where(labels == 1)[0]
        unsafe_indices = np.where(labels == 0)[0]

        plt.figure(figsize=(10, 6))

        # Plot safe samples (blue)
        if len(safe_indices) > 0:
            plt.scatter(
                range(len(safe_indices)),
                logits[safe_indices],
                color='blue',
                label='Safe (ans=1)',
                alpha=0.6
            )

        # Plot unsafe samples (red), placed to the right of the safe ones
        if len(unsafe_indices) > 0:
            plt.scatter(
                range(len(safe_indices), len(logits)),
                logits[unsafe_indices],
                color='red',
                label='Unsafe (ans=0)',
                alpha=0.6
            )

        plt.xlabel('Sample Index')
        plt.ylabel('Unsafe Logit')
        plt.title('Test Case Safety Logits')
        plt.grid(True, linestyle='--', alpha=0.7)

        # Reference lines: y=0, and y=-bias (where logit + bias changes sign)
        plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
        plt.axhline(y=-bias, color='red', linestyle='-', label='Bias')

        total_answers = len(logits)

        def _percent(count):
            # Avoid ZeroDivisionError when no test case is stored.
            return (count / total_answers) * 100 if total_answers else 0.0

        safe_percentage = _percent(len(safe_indices))
        unsafe_percentage = _percent(len(unsafe_indices))
        # Count points below the plotted bias line, which sits at -bias.
        # NOTE(review): the previous code compared against +bias, which did
        # not match the line drawn above — confirm -bias is the intended cut.
        under_bias_percentage = _percent(int(np.count_nonzero(logits < -bias)))

        stats_text = f'Total candidates: {total_answers}\n'
        stats_text += f'Safe: {len(safe_indices)} ({safe_percentage:.1f}%)\n'
        stats_text += f'Unsafe: {len(unsafe_indices)} ({unsafe_percentage:.1f}%)'
        stats_text += f'\nUnder bias: {under_bias_percentage:.1f}%'
        stats_text += f'\nBias: {bias:.2f}'
        stats_text += f'\nP[1|0]: {postprob_1_0:.2f}'

        plt.text(0.02, 0.98, stats_text,
                 transform=plt.gca().transAxes,
                 verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        plt.tight_layout()

        plt.savefig(save_path)
        plt.close()  # Close the figure to free memory

    def plot_grads(self, save_path: str):
        """
        Plot the recorded first-column test-case gradients and save the figure.

        The gradients come from ``self.test_case_ccon_grad_list`` and the
        labels from ``self.ans_list`` (both rolling histories).
        Blue points represent safe samples (ans=1) and red points represent unsafe samples (ans=0).
        Safe samples are plotted on the left side and unsafe samples on the right.

        Parameters
        ----------
        save_path : str
            Destination passed to ``matplotlib.pyplot.savefig``.
        """
        import matplotlib.pyplot as plt

        # reshape(-1) keeps every entry 1-D: torch.concat cannot join 0-d
        # tensors, which a single-test-case batch may have produced.
        # NOTE(review): the two histories are filled by different methods
        # (loss / train_forward); they are assumed to line up — confirm.
        grads = torch.concat([g.reshape(-1) for g in self.test_case_ccon_grad_list], dim=0).detach().cpu().numpy()
        labels = torch.concat([a.reshape(-1) for a in self.ans_list], dim=0).detach().cpu().numpy()

        print("grads shape: ", grads.shape)
        print("labels shape: ", labels.shape)

        # Separate safe and unsafe samples
        safe_indices = np.where(labels == 1)[0]
        unsafe_indices = np.where(labels == 0)[0]

        plt.figure(figsize=(10, 6))

        # Plot safe samples (blue)
        if len(safe_indices) > 0:
            plt.scatter(
                range(len(safe_indices)),
                grads[safe_indices],
                color='blue',
                label='Safe (ans=1)',
                alpha=0.6
            )

        # Plot unsafe samples (red), placed to the right of the safe ones
        if len(unsafe_indices) > 0:
            plt.scatter(
                range(len(safe_indices), len(grads)),
                grads[unsafe_indices],
                color='red',
                label='Unsafe (ans=0)',
                alpha=0.6
            )

        plt.xlabel('Sample Index')
        # The y axis shows gradients here, not logits.
        plt.ylabel('Unsafe Logit Gradient')
        plt.title('Test Case Safety grads')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        # Add a horizontal line at y=0 for reference
        plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)

        plt.tight_layout()

        plt.savefig(save_path)
        plt.close()  # Close the figure to free memory
